#!/usr/bin/env python3

import serial
import argparse
import sys
import time
import pathlib
from collections import deque


def printDebugDevice(msg):
	"""Print a device debug line to stderr, if device debugging is on.

	Nothing is printed unless the function attribute
	``printDebugDevice.enabled`` exists and is truthy.
	"""
	if not getattr(printDebugDevice, "enabled", False):
		return
	print("Device debug:", msg, file=sys.stderr)

def printDebug(msg):
	"""Print a host side debug line to stderr, if debugging is on.

	Nothing is printed unless the function attribute
	``printDebug.enabled`` exists and is truthy.
	"""
	if not getattr(printDebug, "enabled", False):
		return
	print("Debug:", msg, file=sys.stderr)

def printInfo(msg):
	"""Print an informational message to stderr without any prefix."""
	stream = sys.stderr
	print(msg, file=stream)

def printWarning(msg):
	"""Print a message to stderr, prefixed with ``WARNING:``."""
	prefix = "WARNING:"
	print(prefix, msg, file=sys.stderr)

def printError(msg):
	"""Print a message to stderr, prefixed with ``ERROR:``."""
	prefix = "ERROR:"
	print(prefix, msg, file=sys.stderr)

def crc8(data):
	"""Return the 8 bit CRC of an iterable of byte values.

	The CRC uses the polynomial 0x07, processes each byte most
	significant bit first, starts from an initial value of zero and
	inverts all bits of the final result.
	"""
	POLY = 0x07
	reg = 0
	for byte in data:
		reg ^= byte
		for _ in range(8):
			# Remember the bit that is about to be shifted out.
			carry = reg & 0x80
			reg = (reg << 1) & 0xFF
			if carry:
				reg ^= POLY
	return reg ^ 0xFF

class SimplePWMError(Exception):
	"""Error raised by the SimplePWM remote control code.

	Raised for failed device transactions and for invalid setpoint
	strings; main() catches it and prints the message.
	"""

class SimplePWMMsg(object):
	"""Base class and parser of the SimplePWM protocol messages.

	Layout of one serialized message (SIZE bytes):
	  byte 0:   MSG_MAGIC
	  byte 1:   message ID (one of the MSGID_* constants)
	  byte 2:   written as zero by getData(), not evaluated by parse()
	  byte 3..: PAYLOAD_SIZE payload bytes, zero padded
	  last:     crc8() over all preceding bytes
	"""

	PAYLOAD_SIZE	= 8
	SIZE		= 1 + 1 + 1 + PAYLOAD_SIZE + 1

	MSG_MAGIC		= 0xAA
	# Byte value that SimplePWM sends repeatedly to synchronize.
	MSG_SYNCBYTE		= b"\x00"

	MSGID_NOP		= 0
	MSGID_ACK		= 1
	MSGID_NACK		= 2
	MSGID_PING		= 3
	MSGID_PONG		= 4
	MSGID_GET_CONTROL	= 5
	MSGID_CONTROL		= 6
	MSGID_GET_SETPOINTS	= 7
	MSGID_SETPOINTS		= 8
	MSGID_GET_BATVOLT	= 9
	MSGID_BATVOLT		= 10

	@staticmethod
	def _toLe16(value):
		"""Return the low 16 bits of value as 2 little endian bytes."""
		return bytearray( (value & 0xFF, (value >> 8) & 0xFF, ) )

	@staticmethod
	def _fromLe16(data):
		"""Return the integer stored little endian in data[0:2]."""
		return data[0] | (data[1] << 8)

	@classmethod
	def parse(cls, data):
		"""Parse one serialized message.

		data: Byte sequence of exactly SIZE bytes.

		Returns an instance of the message class that belongs to the
		received message ID. Returns None (after printing an error)
		if the magic byte or the CRC does not match, if the message
		ID is unknown or if the payload parser rejects the payload.
		"""
		assert len(data) == cls.SIZE
		magic, msgId = data[0:2]
		payload = data[3:-1]
		crc = data[-1]
		if magic != cls.MSG_MAGIC:
			printError("Received corrupted message: Magic byte mismatch.")
		elif crc8(data[:-1]) != crc:
			printError("Received corrupted message: CRC mismatch.")
		else:
			if msgId == cls.MSGID_NOP:
				return SimplePWMMsg_Nop._parse(payload)
			elif msgId == cls.MSGID_ACK:
				return SimplePWMMsg_Ack._parse(payload)
			elif msgId == cls.MSGID_NACK:
				return SimplePWMMsg_Nack._parse(payload)
			elif msgId == cls.MSGID_PING:
				return SimplePWMMsg_Ping._parse(payload)
			elif msgId == cls.MSGID_PONG:
				return SimplePWMMsg_Pong._parse(payload)
			elif msgId == cls.MSGID_GET_CONTROL:
				return SimplePWMMsg_GetControl._parse(payload)
			elif msgId == cls.MSGID_CONTROL:
				return SimplePWMMsg_Control._parse(payload)
			elif msgId == cls.MSGID_GET_SETPOINTS:
				# The class is named SimplePWMMsg_GetSetpoints.
				# A misspelled name here raises NameError as
				# soon as such a message is received.
				return SimplePWMMsg_GetSetpoints._parse(payload)
			elif msgId == cls.MSGID_SETPOINTS:
				return SimplePWMMsg_Setpoints._parse(payload)
			elif msgId == cls.MSGID_GET_BATVOLT:
				return SimplePWMMsg_GetBatvolt._parse(payload)
			elif msgId == cls.MSGID_BATVOLT:
				return SimplePWMMsg_Batvolt._parse(payload)
			else:
				printError(f"Received unknown message: 0x{msgId:X}")
		return None

	@classmethod
	def _parse(cls, payload):
		"""Create a message instance from the payload bytes.

		This default ignores the payload and calls cls() without
		arguments. Message classes that carry payload fields
		override it.
		"""
		return cls()

	def __init__(self, msgId):
		"""msgId: The message ID that is written into the message."""
		self.msgId = msgId

	def getData(self, payload=None):
		"""Serialize this message.

		payload: Optional payload bytes. A payload shorter than
			 PAYLOAD_SIZE is padded with zero bytes. The passed
			 object is not modified.

		Returns a bytearray of SIZE bytes. A payload longer than
		PAYLOAD_SIZE trips the length assertion.
		"""
		# Work on a private copy, so that the padding below never
		# changes a bytearray that is owned by the caller.
		payload = bytearray() if payload is None else bytearray(payload)
		if len(payload) < self.PAYLOAD_SIZE:
			payload += bytes(self.PAYLOAD_SIZE - len(payload))
		assert len(payload) == self.PAYLOAD_SIZE
		data = bytearray( (self.MSG_MAGIC, self.msgId, 0, ) ) + payload
		# The CRC covers everything before the CRC byte itself.
		data.append(crc8(data))
		assert len(data) == self.SIZE
		return data

class SimplePWMMsg_Nop(SimplePWMMsg):
	"""NOP message. It has no payload fields."""

	MSGID = SimplePWMMsg.MSGID_NOP

	def __init__(self):
		super().__init__(msgId=self.MSGID)

	def __str__(self):
		return "NOP"

class SimplePWMMsg_Ack(SimplePWMMsg):
	"""ACK message. It has no payload fields.

	SimplePWM waits for this message after sending CONTROL or
	SETPOINTS messages.
	"""

	MSGID = SimplePWMMsg.MSGID_ACK

	def __init__(self):
		super().__init__(msgId=self.MSGID)

	def __str__(self):
		return "ACK"

class SimplePWMMsg_Nack(SimplePWMMsg):
	"""NACK message. It has no payload fields."""

	MSGID = SimplePWMMsg.MSGID_NACK

	def __init__(self):
		super().__init__(msgId=self.MSGID)

	def __str__(self):
		return "NACK"

class SimplePWMMsg_Ping(SimplePWMMsg):
	"""PING message. It has no payload fields.

	SimplePWM.ping() sends it and waits for a PONG message.
	"""

	MSGID = SimplePWMMsg.MSGID_PING

	def __init__(self):
		super().__init__(msgId=self.MSGID)

	def __str__(self):
		return "PING"

class SimplePWMMsg_Pong(SimplePWMMsg):
	"""PONG message. It has no payload fields.

	SimplePWM.ping() waits for it after sending a PING message.
	"""

	MSGID = SimplePWMMsg.MSGID_PONG

	def __init__(self):
		super().__init__(msgId=self.MSGID)

	def __str__(self):
		return "PONG"

class SimplePWMMsg_GetControl(SimplePWMMsg):
	"""GET_CONTROL message. It has no payload fields.

	SimplePWM sends it and waits for a CONTROL message.
	"""

	MSGID = SimplePWMMsg.MSGID_GET_CONTROL

	def __init__(self):
		super().__init__(msgId=self.MSGID)

	def __str__(self):
		return "GET_CONTROL"

class SimplePWMMsg_Control(SimplePWMMsg):
	"""CONTROL message. Payload byte 0 holds the control flags."""

	MSGID = SimplePWMMsg.MSGID_CONTROL

	# SimplePWM treats the analog inputs as enabled, if this bit is clear.
	MSG_CTLFLG_ANADIS	= 0x01
	MSG_CTLFLG_EEPDIS	= 0x02 #TODO

	def __init__(self, flags):
		"""flags: Bit mask of MSG_CTLFLG_* values."""
		super().__init__(msgId=self.MSGID)
		self.flags = flags

	@classmethod
	def _parse(cls, payload):
		"""Create the message from the flags in payload byte 0."""
		return cls(payload[0])

	def getData(self):
		"""Serialize the message with the flags as the only payload byte."""
		flagsByte = bytearray([ self.flags ])
		return super().getData(flagsByte)

	def __str__(self):
		return f"CONTROL(flags=0x{self.flags:X})"

class SimplePWMMsg_GetSetpoints(SimplePWMMsg):
	"""GET_SETPOINTS message. Payload byte 0 holds the request flags.

	SimplePWM.getHSL() sends it with MSG_GETSPFLG_HSL set and
	SimplePWM.getRGB() sends it with no flags set.
	"""

	MSGID = SimplePWMMsg.MSGID_GET_SETPOINTS

	MSG_GETSPFLG_HSL	= 0x01

	def __init__(self, flags):
		"""flags: Bit mask of MSG_GETSPFLG_* values."""
		super().__init__(msgId=self.MSGID)
		self.flags = flags

	@classmethod
	def _parse(cls, payload):
		"""Create the message from the flags in payload byte 0."""
		return cls(payload[0])

	def getData(self):
		"""Serialize the message with the flags as the only payload byte."""
		flagsByte = bytearray([ self.flags ])
		return super().getData(flagsByte)

	def __str__(self):
		return f"GET_SETPOINTS(flags=0x{self.flags:X})"

class SimplePWMMsg_Setpoints(SimplePWMMsg):
	"""SETPOINTS message.

	Payload layout: byte 0 holds the flags, byte 1 holds the number of
	setpoints and the setpoints follow as 16 bit little endian values.
	"""

	MSGID = SimplePWMMsg.MSGID_SETPOINTS

	MSG_SPFLG_HSL		= 0x01

	def __init__(self, flags, setpoints):
		"""flags: Bit mask of MSG_SPFLG_* values.
		setpoints: Sequence of integer setpoint values.
		"""
		super().__init__(msgId=self.MSGID)
		self.flags = flags
		self.setpoints = setpoints

	@classmethod
	def _parse(cls, payload):
		"""Create the message from the payload bytes.

		Prints an error and returns None,
		if the payload announces more than 3 setpoints.
		"""
		flags, nrSp = payload[0], payload[1]
		if nrSp > 3:
			printError("SimplePWMMsg_Setpoints: Received invalid nr_sp.")
			return None
		setpoints = []
		# The setpoint values start at payload offset 2.
		for offset in range(2, 2 + 2 * nrSp, 2):
			setpoints.append(cls._fromLe16(payload[offset : offset + 2]))
		return cls(flags, setpoints)

	def getData(self):
		"""Serialize flags, setpoint count and all setpoint values."""
		payload = bytearray([ self.flags, len(self.setpoints) ])
		for value in self.setpoints:
			payload.extend(self._toLe16(value))
		return super().getData(payload)

	def __str__(self):
		return f"SETPOINTS(flags=0x{self.flags:X}, setpoints={self.setpoints})"

class SimplePWMMsg_GetBatvolt(SimplePWMMsg):
	"""GET_BATVOLT message. It has no payload fields.

	SimplePWM.getBatVoltage() sends it and waits for a BATVOLT message.
	"""

	MSGID = SimplePWMMsg.MSGID_GET_BATVOLT

	def __init__(self):
		super().__init__(msgId=self.MSGID)

	def __str__(self):
		return "GET_BATVOLT"

class SimplePWMMsg_Batvolt(SimplePWMMsg):
	"""BATVOLT message.

	Payload layout: two 16 bit little endian values,
	first 'meas' and then 'drop'.
	"""

	MSGID = SimplePWMMsg.MSGID_BATVOLT

	def __init__(self, meas, drop):
		"""meas: Measured battery voltage value.
		drop: Voltage drop value.

		The unit is presumably millivolts, because main() divides both
		values by 1000 before printing them as volts. TODO confirm.
		"""
		super().__init__(msgId=self.MSGID)
		self.meas = meas
		self.drop = drop

	@classmethod
	def _parse(cls, payload):
		"""Create the message from the first 4 payload bytes."""
		meas, drop = (cls._fromLe16(payload[offset : offset + 2])
			      for offset in (0, 2))
		return cls(meas, drop)

	def getData(self):
		"""Serialize 'meas' followed by 'drop'."""
		payload = self._toLe16(self.meas) + self._toLe16(self.drop)
		return super().getData(payload)

	def __str__(self):
		return f"BATVOLT(meas={self.meas}, drop={self.drop})"

class SimplePWM(object):
	"""Host side of the SimplePWM serial remote control protocol.

	Each byte on the serial line is classified by its FLG_8BIT bit.
	Bytes with the bit cleared are collected as device debug text.
	Bytes with the bit set carry one 4 bit nibble of a protocol message
	byte. The lower nibble is transferred first. The upper nibble follows
	and is additionally marked with FLG_8BIT_UPPER.
	"""

	FLG_8BIT	= 0x80 # 8-bit data nibble
	FLG_8BIT_UPPER	= 0x40 # 8-bit upper data nibble
	FLG_8BIT_RSV1	= 0x20 # reserved
	FLG_8BIT_RSV0	= 0x10 # reserved

	MSK_4BIT	= 0x0F # data nibble
	MSK_7BIT	= 0x7F

	# The connection is synchronized again, if nothing has been
	# transmitted for half of this delay (see __wakeup).
	# Presumably the time after which the device enters standby.
	# TODO confirm against the device firmware.
	REMOTE_STANDBY_DELAY_MS	= 5000
	REMOTE_STANDBY_DELAY	= REMOTE_STANDBY_DELAY_MS / 1000

	def __init__(self,
		     port="/dev/ttyUSB0",
		     timeout=1.0,
		     dumpDebugStream=False):
		"""Open the serial port and synchronize to the device.

		port: Name of the serial port device.
		timeout: Default receive timeout, in seconds.
		dumpDebugStream: If True, received device debug text lines
				 are passed to printDebugDevice().

		Raises SimplePWMError, if the device does not answer the
		synchronization pings.
		"""
		self.__serial = serial.Serial(port=port,
					      baudrate=19200,
					      bytesize=8,
					      parity=serial.PARITY_NONE,
					      stopbits=2,
					      timeout=timeout)
		self.__timeout = timeout
		self.__dumpDebugStream = dumpDebugStream
		self.__rxByte = 0		# Pending lower nibble byte.
		self.__rxBuf = bytearray()	# Bytes of the message being received.
		self.__rxMsgs = deque()		# Parsed, not yet consumed messages.
		self.__debugBuf = bytearray()	# Current device debug text line.
		self.__nextSyncTime = time.monotonic()
		self.__wakeup()

	def __synchronize(self):
		"""Send a run of sync bytes that is 2.5 messages long."""
		self.__tx_8bit(SimplePWMMsg.MSG_SYNCBYTE * round(SimplePWMMsg.SIZE * 2.5))

	def __wakeup(self):
		"""Synchronize the connection, if it has been idle for too long.

		Raises SimplePWMError, if the ping fails 11 times in a row.
		"""
		nextSyncTime = self.__nextSyncTime
		# The deadline must be moved before anything is transmitted.
		# __tx_8bit() calls __wakeup() again and that nested call
		# has to return early instead of recursing without end.
		self.__nextSyncTime = time.monotonic() + (self.REMOTE_STANDBY_DELAY / 2)
		if time.monotonic() < nextSyncTime:
			return
		count = 0
		while True:
			try:
				self.__synchronize()
				self.ping(0.1)
			except SimplePWMError:
				count += 1
				if count > 10:
					raise
				continue
			break
		printDebug("Connection synchronized.")

	def __tx_8bit(self, data):
		"""Transmit the bytes in data as two nibble bytes each."""
		self.__wakeup()
#		printDebug(f"TX: {bytes(data)}")
		for d in data:
			lo = ((d & self.MSK_4BIT) |
			      self.FLG_8BIT)
			hi = (((d >> 4) & self.MSK_4BIT) |
			      self.FLG_8BIT |
			      self.FLG_8BIT_UPPER)
			sendBytes = bytes( (lo, hi) )
			self.__serial.write(sendBytes)

	def __rx_7bit(self, dataByte):
		"""Handle one received byte of the device debug text stream.

		The bytes are collected up to the newline character. The
		complete line is printed, if dumping is enabled, and it is
		discarded in any case.
		"""
		self.__debugBuf += dataByte
		if dataByte == b"\n":
			if self.__dumpDebugStream:
				text = self.__debugBuf.decode("ASCII", "ignore").strip()
				printDebugDevice(text)
			self.__debugBuf.clear()

	def __rx_8bit(self, dataByte):
		"""Handle one received nibble byte of the message stream.

		A lower nibble byte is only stored. An upper nibble byte
		completes one message byte. As soon as SimplePWMMsg.SIZE bytes
		have been collected they are parsed and a valid message is
		passed to __rx_message().
		"""
#		printDebug(f"RX: {dataByte}")
		d = dataByte[0]
		if d & self.FLG_8BIT_UPPER:
			data = (self.__rxByte & self.MSK_4BIT)
			data |= (d & self.MSK_4BIT) << 4
			self.__rxByte = 0
			self.__rxBuf.append(data)
			if len(self.__rxBuf) >= SimplePWMMsg.SIZE:
				rxMsg = SimplePWMMsg.parse(self.__rxBuf)
				self.__rxBuf.clear()
				if rxMsg:
					self.__rx_message(rxMsg)
		else:
			self.__rxByte = d

	def __tx_message(self, txMsg):
		"""Serialize and transmit one message."""
		printDebug("TX msg: " + str(txMsg))
		self.__tx_8bit(txMsg.getData())

	def __rx_message(self, rxMsg):
		"""Queue one received message for __waitRxMsg()."""
		printDebug("RX msg: " + str(rxMsg))
		self.__rxMsgs.append(rxMsg)

	def __readNext(self, timeout):
		"""Read at most one byte from the serial port and handle it.

		timeout: Read timeout in seconds.
			 A negative value selects the default timeout.
		"""
		if timeout < 0:
			timeout = self.__timeout
		self.__serial.timeout = timeout
		dataByte = self.__serial.read(1)
		if dataByte:
			if dataByte[0] & self.FLG_8BIT:
				self.__rx_8bit(dataByte)
			else:
				self.__rx_7bit(dataByte)

	def __waitRxMsg(self, msgType=None, timeout=None):
		"""Receive data until a message of type msgType arrives.

		msgType: The message class to wait for. With None no message
			 is ever returned and received messages are dropped.
		timeout: Time limit in seconds. None selects the default
			 timeout. A negative value disables the time limit.

		Returns the received message or None, if the time limit
		passed. Messages of any other type are reported as errors and
		dropped.

		NOTE: Each single byte read waits for up to 'timeout' seconds.
		The call may therefore return later than the time limit.
		"""
		if timeout is None:
			timeout = self.__timeout
		timeoutEnd = None
		if timeout >= 0.0:
			timeoutEnd = time.monotonic() + timeout
		retMsg = None
		while retMsg is None:
			if (timeoutEnd is not None and
			    time.monotonic() >= timeoutEnd):
				break
			self.__readNext(timeout)
			if msgType is not None:
				for rxMsg in self.__rxMsgs:
					if isinstance(rxMsg, msgType):
						retMsg = rxMsg
					else:
						printError("Received unexpected "
							   "message: " + str(rxMsg))
			self.__rxMsgs.clear()
		return retMsg

	def __transact(self, txMsg, replyType, timeout, errorText):
		"""Send txMsg and return the reply of type replyType.

		Raises SimplePWMError(errorText), if no reply of that type
		arrives in time.
		"""
		self.__tx_message(txMsg)
		rxMsg = self.__waitRxMsg(replyType, timeout)
		if not rxMsg:
			raise SimplePWMError(errorText)
		return rxMsg

	def dumpDebugStream(self):
		"""Enable the device debug dump and receive forever.

		This method does not return.
		"""
		self.__dumpDebugStream = True
		while True:
			self.__waitRxMsg(timeout=-1)

	def ping(self, timeout=None):
		"""Send a PING and wait for the PONG.

		Raises SimplePWMError, if no PONG arrives in time.
		"""
		self.__transact(SimplePWMMsg_Ping(),
				SimplePWMMsg_Pong, timeout,
				"Ping failed.")

	def getAnalogEn(self, timeout=None):
		"""Return True, if the analog inputs of the device are enabled.

		Raises SimplePWMError, if the device does not answer in time.
		"""
		rxMsg = self.__transact(SimplePWMMsg_GetControl(),
					SimplePWMMsg_Control, timeout,
					"Failed to get control info.")
		return not (rxMsg.flags & rxMsg.MSG_CTLFLG_ANADIS)

	def setAnalogEn(self, enable, timeout=None):
		"""Enable or disable the analog inputs of the device.

		The current control flags are read from the device, so that
		only the analog disable flag is changed.

		Raises SimplePWMError, if the device does not answer in time.
		"""
		txMsg = self.__transact(SimplePWMMsg_GetControl(),
					SimplePWMMsg_Control, timeout,
					"Failed to get control info.")
		# The received CONTROL message is modified and sent back.
		if enable:
			txMsg.flags &= ~txMsg.MSG_CTLFLG_ANADIS
		else:
			txMsg.flags |= txMsg.MSG_CTLFLG_ANADIS
		self.__transact(txMsg,
				SimplePWMMsg_Ack, timeout,
				"Failed to get acknowledge.")

	def getRGB(self, timeout=None):
		"""Return the list of RGB setpoints reported by the device.

		Raises SimplePWMError, if the device does not answer in time.
		"""
		rxMsg = self.__transact(SimplePWMMsg_GetSetpoints(flags=0),
					SimplePWMMsg_Setpoints, timeout,
					"Failed to get RGB setpoints.")
		return rxMsg.setpoints

	def setRGB(self, rgb, timeout=None):
		"""Disable the analog inputs and set the RGB setpoints.

		rgb: Sequence of setpoint values (16 bit each).

		Raises SimplePWMError, if the device does not answer in time.
		"""
		self.setAnalogEn(False, timeout)
		txMsg = SimplePWMMsg_Setpoints(
				flags=0,
				setpoints=rgb)
		self.__transact(txMsg,
				SimplePWMMsg_Ack, timeout,
				"Failed to set RGB setpoints.")

	def getHSL(self, timeout=None):
		"""Return the list of HSL setpoints reported by the device.

		Raises SimplePWMError, if the device does not answer in time.
		"""
		rxMsg = self.__transact(SimplePWMMsg_GetSetpoints(
						flags=SimplePWMMsg_GetSetpoints.MSG_GETSPFLG_HSL),
					SimplePWMMsg_Setpoints, timeout,
					"Failed to get HSL setpoints.")
		return rxMsg.setpoints

	def setHSL(self, hsl, timeout=None):
		"""Disable the analog inputs and set the HSL setpoints.

		hsl: Sequence of setpoint values (16 bit each).

		Raises SimplePWMError, if the device does not answer in time.
		"""
		self.setAnalogEn(False, timeout)
		txMsg = SimplePWMMsg_Setpoints(
				flags=SimplePWMMsg_Setpoints.MSG_SPFLG_HSL,
				setpoints=hsl)
		self.__transact(txMsg,
				SimplePWMMsg_Ack, timeout,
				"Failed to set HSL setpoints.")

	def getBatVoltage(self, timeout=None):
		"""Return the tuple (meas, drop) of the BATVOLT reply.

		Raises SimplePWMError, if the device does not answer in time.
		"""
		rxMsg = self.__transact(SimplePWMMsg_GetBatvolt(),
					SimplePWMMsg_Batvolt, timeout,
					"Failed to get battery voltage.")
		return (rxMsg.meas, rxMsg.drop)

def parse_setpoints(string):
	"""Convert a comma separated setpoint triple to 16 bit values.

	Each of the three fields has one of these forms:
	  0x...  Raw hexadecimal value in the range 0-0xFFFF.
	  ...%   Percentage in the range 0%-100%, scaled to 0-0xFFFF.
	  ...*   Degrees, wrapped to 0-360 and scaled to 0-0xFFFF.
	  ...    Decimal value in the range 0-255.
		 The byte is used as both the low and the high byte.

	Returns a list of three integers.
	Raises SimplePWMError, if the string is not a valid triple.
	"""
	fields = string.split(",")
	if len(fields) != 3:
		raise SimplePWMError("Setpoints are not a comma separated triple.")

	def convert(text):
		text = text.strip()
		try:
			if text.casefold().startswith("0x"):
				# Raw 16 bit hex value.
				raw = int(text, 16)
				if raw < 0 or raw > 0xFFFF:
					raise SimplePWMError("Setpoint raw value "
						"out of range 0-0xFFFF.")
				return raw
			if text.endswith("%"):
				# Percentage
				percent = float(text[:-1])
				# This form of the check also rejects NaN.
				if not 0.0 <= percent <= 100.0:
					raise SimplePWMError("Setpoint percentage "
						"out of range 0%-100%.")
				return round(percent * 0xFFFF / 100.0)
			if text.endswith("*"):
				# Degrees (0-360)
				degrees = float(text[:-1]) % 360.0
				return round(degrees * 0xFFFF / 360.0)
			# Value 0-255
			level = int(text)
			if not 0 <= level <= 0xFF:
				raise ValueError
			return (level << 8) | level
		except ValueError as e:
			raise SimplePWMError("Setpoint value parse error.")

	return list(map(convert, fields))

def main():
	"""Command line entry point.

	Parses the command line, opens the device and runs the requested
	actions in the order in which they were given on the command line.

	Returns the process exit status: 0 on success and 1, if an error
	occurred or the program was interrupted from the keyboard.
	"""
	try:
		class ArgumentParserOrderedNamespace(argparse.Namespace):
			"""Namespace that records the order of attribute assignments.

			Every assignment is recorded as a (name, value) pair.
			The first assignment to a name that has been recorded
			already ends the setup phase: The list is cleared
			and only the assignments from then on are kept.
			Presumably this relies on argparse storing all defaults
			first and the parsed values after that. TODO confirm.
			"""

			def __init__(self, *args, **kwargs):
				super().__init__(*args, **kwargs)
				# Bypass the recording __setattr__ below.
				super().__setattr__("_setupDone", False)
				super().__setattr__("_orderedArgs", [])

			def __setattr__(self, name, value):
				# Record the name in its command line spelling.
				sanitizedName = name.replace("_", "-")
				if (not self._setupDone and
				    sanitizedName in (n for n, v in self._orderedArgs)):
					super().__setattr__("_setupDone", True)
					del self._orderedArgs[:]
				self._orderedArgs.append( (sanitizedName, value) )
				super().__setattr__(name, value)

			@property
			def orderedArgs(self):
				"""The recorded (name, value) pairs.
				Empty, if the setup phase has not ended.
				"""
				return self._orderedArgs if self._setupDone else ()

		p = argparse.ArgumentParser(description="SimplePWM remote control")
		p.add_argument("-p", "--ping", action="store_true",
			       help="Send a ping to the device and wait for pong reply.")
		p.add_argument("-b", "--get-battery", action="store_true",
			       help="Get the battery voltage.")
		p.add_argument("-a", "--get-analog", action="store_true",
			       help="Get the state of the analog inputs.")
		p.add_argument("-A", "--analog", type=int, default=None,
			       help="Enable/disable the analog inputs.")
		p.add_argument("-r", "--get-rgb", action="store_true",
			       help="Get the current RGB setpoint values.")
		p.add_argument("-R", "--rgb", type=parse_setpoints, default=None,
			       help="Set the RGB setpoint values.")
		p.add_argument("-s", "--get-hsl", action="store_true",
			       help="Get the current HSL setpoint values.")
		p.add_argument("-S", "--hsl", type=parse_setpoints, default=None,
			       help="Set the HSL setpoint values.")
		p.add_argument("-W", "--wait", type=float,
			       help="Delay for a fractional number of seconds.")
		p.add_argument("-L", "--loop", action="store_true",
			       help="Repeat the whole sequence.")
		p.add_argument("-d", "--debug-device", action="store_true",
			       help="Read and dump debug messages from device.")
		p.add_argument("-D", "--debug", action="store_true",
			       help="Enable remote side debugging.")
		p.add_argument("port", nargs="?", type=pathlib.Path,
			       default=pathlib.Path("/dev/ttyUSB0"),
			       help="Serial port.")
		args = p.parse_args(namespace=ArgumentParserOrderedNamespace())

		if args.debug:
			printDebug.enabled = True
		if args.debug_device:
			printDebugDevice.enabled = True

		s = SimplePWM(port=str(args.port),
			      dumpDebugStream=args.debug_device)

		repeat = True
		while repeat:
			for name, value in args.orderedArgs:
				if name == "ping":
					if value:
						s.ping()
				elif name == "get-battery":
					if value:
						meas, drop = s.getBatVoltage()
						print(f"Battery: "
						      f"measured {meas/1000:.2f} V, "
						      f"drop {drop/1000:.2f} V, "
						      f"actual {(meas+drop)/1000:.2f} V")
				elif name == "analog":
					s.setAnalogEn(value)
				elif name == "get-analog":
					if value:
						enabled = "enabled" if s.getAnalogEn() else "disabled"
						print(f"Analog inputs state: "
						      f"{enabled}")
				elif name == "rgb":
					s.setRGB(value)
				elif name == "get-rgb":
					if value:
						rgb = s.getRGB()
						print(f"RGB setpoints: "
						      f"{rgb[0]*100/0xFFFF:.1f}%, "
						      f"{rgb[1]*100/0xFFFF:.1f}%, "
						      f"{rgb[2]*100/0xFFFF:.1f}%")
				elif name == "hsl":
					s.setHSL(value)
				elif name == "get-hsl":
					if value:
						hsl = s.getHSL()
						print(f"HSL setpoints: "
						      f"{hsl[0]*100/0xFFFF:.1f}%, "
						      f"{hsl[1]*100/0xFFFF:.1f}%, "
						      f"{hsl[2]*100/0xFFFF:.1f}%")
				elif name == "wait":
					time.sleep(value)
				elif name == "loop":
					# Start over with the first action.
					# Actions after the loop option never run.
					repeat = True
					break
			else:
				# All actions ran and no loop was requested.
				repeat = False
		if args.debug_device:
			s.dumpDebugStream()
	except SimplePWMError as e:
		printError(e)
		return 1
	except serial.SerialException as e:
		# For example: The serial port does not exist or is not
		# accessible. Report it like any other error instead of
		# terminating with a traceback.
		printError(e)
		return 1
	except KeyboardInterrupt:
		printInfo("Interrupted.")
		return 1

	return 0

# Only run the command line tool when executed as a script, so that
# importing this file as a module does not exit the interpreter.
if __name__ == "__main__":
	sys.exit(main())
